{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Example: Distributing Large Embedding tables over TPU cores",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4SM2yBfGVJCk",
        "colab_type": "text"
      },
      "source": [
        "# Distributing Large Embedding tables over TPU cores\n",
        "\n",
        "Use Colab Cloud TPU\n",
        "\n",
        "<a href=\"https://cloud.google.com/tpu/\"><img valign=middle src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"></a>\n",
        "\n",
        "* On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
        "* The cell below makes sure you have access to a TPU on Colab."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2skQNymdVCRB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Runtime > Change runtime type > Hardware accelerator'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iGsiv-BNWVBM",
        "colab_type": "text"
      },
      "source": [
        "## [RUNME] Install Colab TPU compatible PyTorch/TPU wheels and dependencies\n",
        "This may take up to ~2 minutes\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LXIEErvdWeGG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "VERSION = \"nightly\"  #@param [\"1.5\" , \"20200325\", \"nightly\"]\n",
        "!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
        "!python pytorch-xla-env-setup.py --version $VERSION"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XO-lZMnnYm7P",
        "colab_type": "text"
      },
      "source": [
        "## Description and Objective\n",
        "\n",
        "The goal of this notebook is to illustrate a technique for distributing an embedding table over many TPU cores. This technique comes in handy when the embedding table is too large to fit on a single TPU device.\n",
        "\n",
        "We will use the popular [`fairseq`](https://github.com/pytorch-tpu/fairseq) repository to demonstrate the training, with parameters that would cause a regular run to fail with an `HBM out-of-memory` error.\n",
        "\n",
        "### Explanation of the technique:\n",
        "\n",
        "The trick can be roughly summarized as follows:\n",
        "- Each core will own a slice of the embedding table, sliced by the embedding dimension.\n",
        "  - e.g. Core 1 will own dimensions 1-10, Core 2 will own 11-20, and so on.\n",
        "  - Every core will have the full list of entities being embedded.\n",
        "- During forward pass:\n",
        "  - Every core will share its input with other cores and end up with the full batch input.\n",
        "  - Then get the corresponding embedding dimensions for the full input.\n",
        "  - Do an all-gather and collect the other embedding dimensions from the other cores.\n",
        "  - At this point, every core has the full embeddings for the full input.\n",
        "  - Then each core will slice the full batch and end up with only the samples in the batch belonging to itself.\n",
        "  - Then the forward will resume normally.\n",
        "- During backward, it'll perform the opposite operations and each core will update the slice of the embedding table that it owns.\n",
        "\n",
        "## Setting up the task\n",
        "\n",
        "We will modify the translation workload [tutorial](https://cloud.google.com/tpu/docs/tutorials/transformer-pytorch) which uses `fairseq`'s Transformer model. Let's begin by first installing fairseq, and downloading the data."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mrwgZmZVgk6D",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "fairseq_path = '/tmp/fairseq'\n",
        "!git clone https://github.com/pytorch-tpu/fairseq.git -b tpu {fairseq_path}\n",
        "!pip install --editable {fairseq_path}\n",
        "!wget https://dl.fbaipublicfiles.com/fairseq/data/wmt18_en_de_bpej32k.zip\n",
        "!unzip wmt18_en_de_bpej32k.zip -d /tmp\n",
        "\n",
        "import sys\n",
        "sys.path.append(fairseq_path)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qxIRT7xRglJ5",
        "colab_type": "text"
      },
      "source": [
        "Now let's define `DistributedEmbedding` and the wrapper around the `fairseq_model` that will use it. We override the original model's embedding table, add the forward and backward methods described above, and add a couple of other methods to be used later."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xVKXBwl5eqlN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch_xla.core.xla_model as xm\n",
        "\n",
        "\n",
        "class DistributedEmbedding(nn.Module):\n",
        "\n",
        "  def __init__(self, vocab_size, embedding_size, world_size=None,\n",
        "               batch_dim=0):\n",
        "    super(DistributedEmbedding, self).__init__()\n",
        "    self._embedding_size = embedding_size\n",
        "    self._world_size = world_size\n",
        "    self._batch_dim = batch_dim\n",
        "    assert embedding_size % self._world_size == 0, \\\n",
        "        (\"For this example to work, please provide embedding size \"\n",
        "         \"a multiple of {}\".format(self._world_size))\n",
        "    self._sliced_emb_size = self._embedding_size // self._world_size\n",
        "    self.embeddings = nn.Embedding(vocab_size, self._sliced_emb_size)\n",
        "\n",
        "  @property\n",
        "  def _rank(self):\n",
        "    # We allow delaying the rank setting to allow module creation at global scope.\n",
        "    return xm.get_ordinal()\n",
        "\n",
        "  def _get_embedding_pad(self):\n",
        "    size = self._sliced_emb_size\n",
        "    return self._rank * size, (self._world_size - 1 - self._rank) * size\n",
        "\n",
        "  def forward(self, batch):\n",
        "    bsz = batch.size(self._batch_dim)\n",
        "    fullbatch = xm.all_gather(\n",
        "        batch.type(torch.float), dim=self._batch_dim).type(batch.dtype)\n",
        "    embeds = self.embeddings(fullbatch)\n",
        "    pembeds = xm.all_gather(embeds, dim=-1)\n",
        "    sliced_embeds = torch.narrow(pembeds, self._batch_dim, self._rank*bsz, bsz)\n",
        "    # We return both sub-batch's full embeddings and fullbatch's sliced embeddings\n",
        "    # The former is needed to do forward pass for the remainder of the model\n",
        "    # The latter is needed to do backward pass and update the embedding table.\n",
        "    return sliced_embeds.clone().detach().requires_grad_(True), embeds\n",
        "\n",
        "  def backward(self, fullbatch_slicedemb, grad):\n",
        "    # Gradient at this point has the full embedding dimensions\n",
        "    # and only contains info on the samples this core processed.\n",
        "    grad = xm.all_gather(grad, dim=self._batch_dim)\n",
        "    size = self._sliced_emb_size\n",
        "    sliced_grad = torch.narrow(grad, grad.ndim-1, self._rank * size, size)\n",
        "    fullbatch_slicedemb.backward(sliced_grad)\n",
        "\n",
        "\n",
        "class TransformerWithDistributedEmbeddings(nn.Module):\n",
        "\n",
        "  def __init__(self, model, emb_size, world_size):\n",
        "    super(TransformerWithDistributedEmbeddings, self).__init__()\n",
        "    self.model = model\n",
        "    self.dropout = self.model.encoder.dropout\n",
        "    self.embedding_size = emb_size\n",
        "    self._world_size = world_size\n",
        "    self._distribute_embeddings()\n",
        "\n",
        "  def _distribute_embeddings(self):\n",
        "    vocab_size = self.model.encoder.embed_tokens.weight.size(0)\n",
        "    self.padding_idx = self.model.encoder.embed_tokens.padding_idx\n",
        "    self.embedding = DistributedEmbedding(\n",
        "        vocab_size, self.embedding_size, world_size=self._world_size)\n",
        "    # We remove the original embedding layer.\n",
        "    self.model.encoder.embed_tokens = None\n",
        "\n",
        "  def init_emb_weights(self):\n",
        "    std = self.embedding_size\n",
        "    nn.init.normal_(self.embedding.embeddings.weight, mean=0, std=std**-0.5)\n",
        "    nn.init.constant_(self.embedding.embeddings.weight[self.padding_idx], 0)\n",
        "\n",
        "  def forward(self, **kwargs):\n",
        "    inputs = kwargs['src_tokens']\n",
        "    embedded_batch, emb_globalbatch_dimsliced = self.embedding(inputs)\n",
        "    x = F.dropout(\n",
        "        embedded_batch, p=self.dropout, training=self.model.training)\n",
        "    # Hack the encoder's `forward_embedding` method so that it returns what\n",
        "    #   was just computed in distributed fashion.\n",
        "    # This needs to return two values.\n",
        "    self.model.encoder.forward_embedding = lambda _: (x, None)\n",
        "    return self.model(**kwargs), embedded_batch, emb_globalbatch_dimsliced\n",
        "\n",
        "  def emb_backward(self, *args, **kwargs):\n",
        "    self.embedding.backward(*args, **kwargs)\n",
        "\n",
        "  def non_distr_params(self):\n",
        "    # Last parameter is the added distributed embedding table.\n",
        "    last_index = len(list(self.parameters())) - 1\n",
        "    for i, _ in enumerate(self.parameters()):\n",
        "      if i != last_index:\n",
        "        yield _"
      ],
      "execution_count": 0,
      "outputs": []
    },
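    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the custom backward more concrete, here is a minimal single-process NumPy sketch of the gradient routing that `DistributedEmbedding.backward` performs. It is an illustration under stated assumptions, not part of the training code: Python lists stand in for the TPU cores, `np.concatenate` stands in for `xm.all_gather`, `np.add.at` stands in for the embedding's autograd, and the sizes and random gradients are made up.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical sizes, chosen only for illustration.\n",
        "world_size, vocab_size, emb_size, local_bsz = 4, 10, 8, 3\n",
        "slice_size = emb_size // world_size\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# Token ids seen by every core after the input all-gather.\n",
        "full_batch = rng.integers(0, vocab_size, size=world_size * local_bsz)\n",
        "# Each core holds the gradient w.r.t. the full embeddings of its own samples.\n",
        "local_grads = [rng.normal(size=(local_bsz, emb_size))\n",
        "               for _ in range(world_size)]\n",
        "\n",
        "# All-gather the gradients along the batch dimension.\n",
        "full_grad = np.concatenate(local_grads, axis=0)\n",
        "\n",
        "shard_grads = []\n",
        "for rank in range(world_size):\n",
        "  # Keep only the embedding dimensions that this core owns ...\n",
        "  sliced = full_grad[:, rank * slice_size:(rank + 1) * slice_size]\n",
        "  # ... and accumulate them into the rows of the looked-up tokens.\n",
        "  shard_grad = np.zeros((vocab_size, slice_size))\n",
        "  np.add.at(shard_grad, full_batch, sliced)\n",
        "  shard_grads.append(shard_grad)\n",
        "\n",
        "# Reference: the gradient of a single, unsharded table.\n",
        "reference = np.zeros((vocab_size, emb_size))\n",
        "np.add.at(reference, full_batch, full_grad)\n",
        "matches = np.allclose(np.concatenate(shard_grads, axis=-1), reference)\n",
        "print(matches)\n",
        "```\n",
        "\n",
        "Each simulated core only ever touches its own `slice_size` columns, which is why the embedding table needs no gradient reduction across cores."
      ]
    },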
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5wahSapg6YAE",
        "colab_type": "text"
      },
      "source": [
        "Let's now create the `Namespace`, which `fairseq` uses to define the task, dataset, model and more."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HCcPvjX36i_7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from fairseq import options\n",
        "\n",
        "# The following  leads to an HBM OOM w/ the regular way of embedding tokens.\n",
        "#   On v3-8:\n",
        "EMBEDDING_SIZE = 4096\n",
        "INPUT_SHAPES = [[64, 64],]\n",
        "#   On v2-8:\n",
        "EMBEDDING_SIZE = 2048\n",
        "INPUT_SHAPES = [[64, 64],]\n",
        "\n",
        "args = [\n",
        "  '/tmp/wmt18_en_de_bpej32k',\n",
        "  '--arch=transformer_wmt_en_de',\n",
        "  '--max-target-positions=64',\n",
        "  '--max-source-positions=64',\n",
        "  '--attention-dropout=0.0',\n",
        "  '--dropout=0.0',\n",
        "  '--no-progress-bar',\n",
        "  '--criterion=label_smoothed_cross_entropy',\n",
        "  '--source-lang=en',\n",
        "  '--target-lang=de',\n",
        "  '--lr-scheduler=inverse_sqrt',\n",
        "  '--min-lr=1e-09',\n",
        "  '--label-smoothing=0.1',\n",
        "  '--optimizer=adam',\n",
        "  '--adam-betas',\n",
        "  '(0.9, 0.98)',\n",
        "  '--warmup-init-lr=1e-07',\n",
        "  '--lr=0.0005',\n",
        "  '--warmup-updates=4000',\n",
        "  '--weight-decay=0.0',\n",
        "  '--no-save',\n",
        "  '--log-interval=20',\n",
        "  '--num-workers=1',\n",
        "  '--disable-validation',\n",
        "  '--max-epoch=1',\n",
        "  '--encoder-embed-dim={}'.format(EMBEDDING_SIZE),\n",
        "  '--decoder-embed-dim=512',\n",
        "]\n",
        "\n",
        "parser = options.get_training_parser()\n",
        "args = options.parse_args_and_arch(parser, input_args=args)\n",
        "args.input_shapes = INPUT_SHAPES\n",
        "args.use_gpu = False"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A7A1IMTFjw9f",
        "colab_type": "text"
      },
      "source": [
        "Now let's create the models. We are still at global scope: building the model here, once, before the processes are forked saves host memory. Let's also define the training function:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G8ekUtnqe90X",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import math\n",
        "import torch_xla.distributed.xla_multiprocessing as xmp\n",
        "import torch_xla.debug.metrics as met\n",
        "import torch_xla.distributed.parallel_loader as pl\n",
        "from fairseq import tasks, optim\n",
        "\n",
        "# GLOBAL SCOPE\n",
        "NUM_DEVICES = 8\n",
        "# Set up fairseq dataclasses\n",
        "task = tasks.setup_task(args)\n",
        "task.load_dataset(args.train_subset, epoch=0)\n",
        "criterion = task.build_criterion(args)\n",
        "# This is our initial model.\n",
        "fairseq_model = task.build_model(args)\n",
        "# Let's set the embedding table dimension to a high number, explicitly here:\n",
        "# Distributing the embedding table now it with the \n",
        "distr_model = TransformerWithDistributedEmbeddings(\n",
        "    fairseq_model, EMBEDDING_SIZE, world_size=NUM_DEVICES)\n",
        "distr_model.train(), fairseq_model.train()\n",
        "wrapped_model = xmp.MpModelWrapper(distr_model)\n",
        "\n",
        "\n",
        "def train(index):\n",
        "  device = xm.xla_device()\n",
        "  m = wrapped_model.to(device)\n",
        "  # Let's initialize the table weights.\n",
        "  #   We seed per process so every table inits to a different set of weights.\n",
        "  torch.manual_seed(xm.get_ordinal())\n",
        "  m.init_emb_weights()   \n",
        "  torch.manual_seed(args.seed)\n",
        "  epoch_itr = task.get_batch_iterator(\n",
        "      dataset=task.dataset(args.train_subset),\n",
        "      max_tokens=args.max_tokens,\n",
        "      max_sentences=args.max_sentences,\n",
        "      max_positions=(args.max_source_positions, args.max_target_positions),\n",
        "      ignore_invalid_inputs=True,\n",
        "      required_batch_size_multiple=args.required_batch_size_multiple,\n",
        "      seed=args.seed,\n",
        "      num_shards=NUM_DEVICES,\n",
        "      shard_id=xm.get_ordinal(),\n",
        "      num_workers=args.num_workers,\n",
        "      epoch=0,\n",
        "  )\n",
        "  itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=False, shuffle=False)\n",
        "  para_loader = pl.MpDeviceLoader(itr, device)\n",
        "  # The distributed embedding needs to have its own optimizer, because\n",
        "  #   the embedding table is sharded and we do not want gradient reduction\n",
        "  #   happening across all cores.\n",
        "  # Thus, we create two optimizers, one for the distributed embedding, and \n",
        "  #   another for the remainder of the model. The latter's gradients will be\n",
        "  #   reduced as usual, and we'll call the custom backward on the other one. \n",
        "\n",
        "  model_optimizer = optim.build_optimizer(args, m.non_distr_params())\n",
        "  model_lr_scheduler = optim.lr_scheduler.build_lr_scheduler(\n",
        "      args, model_optimizer)  # learning rate warmup\n",
        "  demb_optimizer = optim.build_optimizer(args, m.embedding.parameters())\n",
        "\n",
        "  running_loss = 0\n",
        "  for step, batch in enumerate(para_loader, 1):\n",
        "    # We will do 100 steps to illustrate the training avoids any OOMs.\n",
        "    if step > 100 or step == len(itr):\n",
        "        break  # drop the last batch\n",
        "    model_optimizer.zero_grad(), demb_optimizer.zero_grad()\n",
        "    demb_optimizer.set_lr(model_optimizer.get_lr())\n",
        "    net_output, fewsamples_fullemb, fullsamples_slicedemb = \\\n",
        "        m(**batch['net_input'])\n",
        "    loss, _ = criterion.compute_loss(m.model, net_output, batch)\n",
        "    loss.backward()  # this only back-propagates up to the embeddings\n",
        "    xm.reduce_gradients(model_optimizer)\n",
        "    model_optimizer.clip_grad_norm(args.clip_norm)\n",
        "    model_optimizer.step()  # update model weights up to the embeddings\n",
        "    # Custom backward to handle distributed embeddings\n",
        "    m.emb_backward(fullsamples_slicedemb, fewsamples_fullemb.grad)\n",
        "    demb_optimizer.clip_grad_norm(args.clip_norm)\n",
        "    demb_optimizer.step()  # update embeddings\n",
        "    # Learning rate warmup\n",
        "    model_lr_scheduler.step_update(step)\n",
        "    # Record loss for reporting later.\n",
        "    running_loss += loss / math.log(2) / batch['ntokens']\n",
        "    if step % args.log_interval:\n",
        "      continue\n",
        "    running_loss = running_loss.item()\n",
        "    update = 'Step {}, loss {:.4f}'.format(step, running_loss / step)\n",
        "    xm.add_step_closure(lambda s: xm.master_print(s, flush=True), args=(update,))\n",
        "  xm.master_print(met.metrics_report())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VS5EGD1fmYB",
        "colab_type": "text"
      },
      "source": [
        "Now let's fire up the training and observe that it doesn't crash with an HBM OOM! Note that the first few steps take longer because of the initial compilations."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6OzYnmCKfpdm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "xmp.spawn(train, args=(), nprocs=NUM_DEVICES, start_method='fork')"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}